{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1173893c4f0ea56",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Chunked Pooling\n",
    "This notebook explains how chunked pooling can be implemented. First, install the requirements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d02a920f-cde0-4035-9834-49b087aab5cc",
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58a8fbc1e477db48",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Next, we load the model used to compute the embeddings. We use `jinaai/jina-embeddings-v2-base-en`, but any other model that supports mean pooling works. Models with a large maximum context length are preferred, because the whole text is encoded in a single pass before it is split into chunks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1380abf7acde9517",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "from chunked_pooling import chunked_pooling, chunk_by_sentences\n",
    "\n",
    "# load model and tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cc0c1162797ffb0",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Now we define the text to encode and split it into chunks. The `chunk_by_sentences` function returns the chunks together with their span annotations. These annotations specify which tokens belong to each chunk, which chunked pooling needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ef392f3437ef82e",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Chunks:\n",
      "- \"Berlin is the capital and largest city of Germany, both by area and by population.\"\n",
      "- \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"\n",
      "- \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n"
     ]
    }
   ],
   "source": [
    "input_text = \"Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n",
    "\n",
    "# determine chunks\n",
    "chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)\n",
    "print('Chunks:\\n- \"' + '\"\\n- \"'.join(chunks) + '\"')\n"
   ]
  },
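  {
   "cell_type": "markdown",
   "id": "toy-span-annotations-md",
   "metadata": {},
   "source": [
    "To make the idea of span annotations concrete, the following toy sketch derives token spans for a pre-tokenized text. It assumes that each span is a `(start, end)` pair of token indices with an exclusive end. `sentence_spans` and the toy token list are hypothetical and are not part of `chunked_pooling`; the real `chunk_by_sentences` may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-span-annotations-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of chunked_pooling): one span per sentence.\n",
    "# Assumption: a span is a (start, end) pair of token indices, end exclusive.\n",
    "def sentence_spans(tokens):\n",
    "    spans, start = [], 0\n",
    "    for i, tok in enumerate(tokens):\n",
    "        if tok == '.':\n",
    "            # close the current sentence right after its final period\n",
    "            spans.append((start, i + 1))\n",
    "            start = i + 1\n",
    "    if start < len(tokens):\n",
    "        # trailing tokens without a final period form their own span\n",
    "        spans.append((start, len(tokens)))\n",
    "    return spans\n",
    "\n",
    "toy_tokens = ['Berlin', 'is', 'big', '.', 'It', 'is', 'old', '.']\n",
    "toy_spans = sentence_spans(toy_tokens)\n",
    "print(toy_spans)"
   ]
  },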
  {
   "cell_type": "markdown",
   "id": "9ac41fd1f0560da7",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
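    "Now we encode the chunks in two ways: with traditional chunking, where each chunk is encoded on its own, and with context-sensitive chunked pooling, where the full text is encoded once and the token embeddings are then pooled per chunk:"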
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "abe3d93b9e6609b9",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
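    "# traditional approach: chunk first, then encode each chunk on its own\n",
    "embeddings_traditional_chunking = model.encode(chunks)\n",
    "\n",
    "# context-sensitive chunked pooling: encode the full text once, then pool the token embeddings per chunk\n",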
    "inputs = tokenizer(input_text, return_tensors='pt')\n",
    "model_output = model(**inputs)\n",
    "embeddings = chunked_pooling(model_output, [span_annotations])[0]"
   ]
  },
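  {
   "cell_type": "markdown",
   "id": "toy-span-pooling-md",
   "metadata": {},
   "source": [
    "The pooling step itself can be illustrated with a few lines of NumPy. This is a minimal sketch, assuming that chunked pooling mean-pools the token embeddings inside each `(start, end)` span. `pool_spans` and the toy array are hypothetical stand-ins, not the `chunked_pooling` implementation used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-span-pooling-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper (not the library function): mean-pool per span.\n",
    "# token_embeddings has shape (num_tokens, dim); spans are (start, end) pairs, end exclusive.\n",
    "def pool_spans(token_embeddings, spans):\n",
    "    return [token_embeddings[start:end].mean(axis=0) for start, end in spans]\n",
    "\n",
    "# six toy tokens with two dimensions each\n",
    "toy_token_embeddings = np.arange(12, dtype=float).reshape(6, 2)\n",
    "toy_chunk_embeddings = pool_spans(toy_token_embeddings, [(0, 2), (2, 6)])\n",
    "print(toy_chunk_embeddings)"
   ]
  },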
  {
   "cell_type": "markdown",
   "id": "e84b1b9d48cb6367",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Finally, we compare the cosine similarity between the embedding of the word \"Berlin\" and each chunk embedding. The similarity should be higher for context-sensitive chunked pooling, especially for chunks that refer to Berlin without naming it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da0cec59a3ece76",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity_new(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.849546\n",
      "similarity_trad(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.84862185\n",
      "similarity_new(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.82489026\n",
      "similarity_trad(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.7084338\n",
      "similarity_new(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.84980094\n",
      "similarity_trad(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.7534553\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def cos_sim(x, y):\n",
    "    # cosine similarity between two 1-D vectors\n",
    "    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))\n",
    "\n",
    "\n",
    "berlin_embedding = model.encode('Berlin')\n",
    "\n",
    "for chunk, new_embedding, trad_embedding in zip(chunks, embeddings, embeddings_traditional_chunking):\n",
    "    print(f'similarity_new(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, new_embedding))\n",
    "    print(f'similarity_trad(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, trad_embedding))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
